{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 124,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jsonlines\n",
    "from collections import defaultdict\n",
    "import openai\n",
    "import json\n",
    "import regex\n",
    "from typing import Dict, Any, Tuple\n",
    "from tqdm import tqdm "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 120,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = defaultdict(list)\n",
    "test_data = defaultdict(list)\n",
    "\n",
    "for i in range(5):\n",
    "    with jsonlines.open(f'train_{i}.jsonl') as f:\n",
    "        for line in f.iter():\n",
    "            train_data[i].append(line)\n",
    "\n",
    "    with jsonlines.open(f'test_{i}.jsonl') as f:\n",
    "        for line in f.iter():\n",
    "            test_data[i].append(line)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "openai.api_key = 'your_api_key'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Must modify to your personal fine-tuned model!\n",
    "\n",
    "ft_model_4 = {\n",
    "    0: 'ft:gpt-4o-2024-08-06:personal:train0:AcZ1PhB4',\n",
    "    1: 'ft:gpt-4o-2024-08-06:personal:train1:AcZ2DyL1',\n",
    "    2: 'ft:gpt-4o-2024-08-06:personal:train2:AcZ3YS9m',\n",
    "}\n",
    "\n",
    "ft_model_3 = {\n",
    "    0: 'ft:gpt-3.5-turbo-1106:personal:train0:AcZ6UY49',\n",
    "    1: 'ft:gpt-3.5-turbo-1106:personal:train1:AcZ7idHm',\n",
    "    2: 'ft:gpt-3.5-turbo-1106:personal:train2:AcZ6f95m',\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "time_dict = {\n",
    "    'days': 24,\n",
    "    'day': 24,\n",
    "    'weeks': 24*7,\n",
    "    'week': 24*7,\n",
    "    'hours': 1,\n",
    "    'hour': 1,\n",
    "    'h': 1,\n",
    "    'minute': 1/60,\n",
    "    'min': 1/60,\n",
    "    's': 1/3600\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 123,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_answers(model: str, message: dict, temperature=0.0) -> tuple:\n",
    "    completion = openai.ChatCompletion.create(\n",
    "        model=model,\n",
    "        messages=message['messages'][:2],\n",
    "        temperature=temperature\n",
    "    )\n",
    "\n",
    "    prediction = json.loads(completion.choices[0].message['content'])\n",
    "    true = json.loads(message['messages'][2]['content'])\n",
    "\n",
    "    return true, prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2094it [1:01:04,  1.75s/it]\n"
     ]
    }
   ],
   "source": [
    "true_label, prediction_label = defaultdict(list), defaultdict(list)\n",
    "\n",
    "for i in range(1):\n",
    "    model = ft_model_3[i]\n",
    "    testset = test_data[i]\n",
    "\n",
    "    for j, message in tqdm(enumerate(testset), total=len(testset)):\n",
    "        try:\n",
    "            true, prediction = get_answers(model, message)\n",
    "        except Exception as e:\n",
    "            print (f'gpt3.5-{i}-{j} : ', e)\n",
    "            continue\n",
    "        else:\n",
    "            true_label[f'gpt3.5-{i}'].append(true)\n",
    "            prediction_label[f'gpt3.5-{i}'].append(prediction)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 126,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2094/2094 [1:20:43<00:00,  2.31s/it]  \n"
     ]
    }
   ],
   "source": [
    "for i in range(1):\n",
    "    model = ft_model_4[i]\n",
    "    testset = test_data[i]\n",
    "\n",
    "    for j, message in tqdm(enumerate(testset), total=len(testset)):\n",
    "        try:\n",
    "            true, prediction = get_answers(model, message)\n",
    "        except Exception as e:\n",
    "            print (f'gpt4-{i}-{j} : ', e)\n",
    "            continue\n",
    "        else:\n",
    "            true_label[f'gpt4-{i}'].append(true)\n",
    "            prediction_label[f'gpt4-{i}'].append(prediction)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 128,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "with open('true_label_save.pickle', 'wb') as f:\n",
    "    pickle.dump(true_label, f)\n",
    "\n",
    "with open('prediction_label_save.pickle', 'wb') as f:\n",
    "    pickle.dump(prediction_label, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llmminer",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
